""" Get all the tables and figures in the descriptives section.

"""
import stata_setup
stata_setup.config("/Applications/Stata/", "se")

from pystata import stata

import os
import copy
import pandas as pd
import json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc
# Prefer the absolute import; fall back to the package-relative form when the
# `src` package cannot be resolved. Only import failures trigger the fallback,
# so any other error raised while loading the module is no longer hidden.
try:
    from src.interpolation.splines import UCGrid
except ImportError:
    from ..interpolation.splines import UCGrid
import sys
sys.path.append('./')

from src.models_new import visualizations
from src.run_scripts import utils



#%% CONFIGURATION -----------------------------------------------------------------------
# Optional command-line arguments:
#   argv[1]: output path prefix (tables are written to <prefix>tables/, figures to
#            <prefix>figures/)
#   argv[2]: project root path
# Both must be supplied together; if either is missing, both fall back to the
# defaults below.
try:
    overleaf_path = sys.argv[1]
    root_path = sys.argv[2]
except IndexError:
    # Fewer than two arguments were given on the command line.
    overleaf_path = "/Users/nicholasvreugdenhil/Dropbox (Personal)/Apps/Overleaf/bbm_test/draft/reports/revision/"
    root_path = './'

# Global matplotlib styling: render text through LaTeX with a serif font family,
# putting 'lmodern' first in the serif font list.
plt.rcParams['text.latex.preamble'] = r"\usepackage{lmodern}"
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['lmodern'] + plt.rcParams['font.serif']
rc('text', usetex=True)

# Font sizes, grouped by rc category.
plt.rc('font', size=9)                       # default text size
plt.rc('axes', titlesize=9, labelsize=9)     # axes title and x/y labels
for _tick_group in ('xtick', 'ytick'):
    plt.rc(_tick_group, labelsize=9)         # tick labels
plt.rc('legend', fontsize=8)                 # legend entries
plt.rc('figure', titlesize=10)               # figure title

#%% IMPORT BOOTSTRAP CONFIDENCE INTERVALS -----------------------------------------------
# Set to False to skip the bootstrap files; the tables below then receive None
# for both confidence bounds.
include_ci = True

if not include_ci:
    ci_lower = None
    ci_upper = None
else:
    # One parameter vector per bootstrap output file (files are numbered 0..199).
    bootstrap_by_seed = {
        seed: pd.read_csv(
            f'./models/smm/bootstrap_output/params_smm_with_diff_new{seed}.csv',
            index_col=[0]
        ).squeeze("columns")
        for seed in range(200)
    }
    df = pd.DataFrame(bootstrap_by_seed)

    # 2.5% / 97.5% quantiles across bootstrap files, per parameter. The rho
    # entries are dropped here because their bounds are read from separate
    # files just below.
    rho_names = ['rho_0', 'rho_1', 'rho_2', 'rho_3']
    draws_by_param = df.T
    ci_lower = draws_by_param.quantile(0.025).drop(rho_names)
    ci_upper = draws_by_param.quantile(0.975).drop(rho_names)

    rho_lower = pd.read_csv(
        './models/smm_input/rho_lower.csv', index_col=[0]
    ).squeeze("columns")
    rho_upper = pd.read_csv(
        './models/smm_input/rho_upper.csv', index_col=[0]
    ).squeeze("columns")

    ci_lower = pd.concat([ci_lower, rho_lower])
    ci_upper = pd.concat([ci_upper, rho_upper])

#%% IMPORT DATA -------------------------------------------------------------------------
# Model inputs for the monthly specification, as returned by utils.read_in_data.
data, delta, rho, c, weights = utils.read_in_data(
    # run options
    time_period='month',
    p_exit=1.0,
    use_myopic=True,
    # processed data
    path_df_state='./data_py/processed/states',
    path_df_contracts='./data_py/processed/contracts_final.csv',
    # SMM inputs
    path_moments_data='./models/smm_input/moments_empirical.csv',
    path_coefs_data='./models/smm_input/coefs_data',
    path_delta='./models/smm_input/delta.csv',
    path_entry_cost='./models/smm_input/entry_cost.csv',
    path_rho='./models/smm_input/rho.csv',
    # model objects computed in earlier stages
    path_n_rigs="./models/first_stage/n_rigs",
    path_surplus_components="./models/surplus/surplus_components",
    path_surplus_grid='./models/surplus/surplus_grid_2_low_month.npy',
    path_price_match_values='./models/price_match/price_match_values',
    path_prob_match_predict_contracts='./models/robustness/prob_match_predict_contracts',
    path_prob_match_predict='./models/robustness/prob_match_predict'
)

# Point estimates of the SMM parameters, squeezed to a Series.
params = pd.read_csv(
    './models/smm/params_smm_with_diff_new.csv', index_col=[0]
).squeeze("columns")


def _read_contracts(path):
    """Read a contracts CSV with its first column as index and dates parsed."""
    return pd.read_csv(path, index_col=[0], parse_dates=['fixture_date', 'start'])


df_contracts = _read_contracts('./data_py/processed/contracts_final.csv')
df_contracts_no_deepening = _read_contracts(
    './data_py/processed/contracts_final_no_deepening.csv'
)

# Monthly state data; 'boom' flags rows whose gas value is at or above the sample mean.
df_state = pd.read_csv('./data_py/processed/states_month.csv')
df_state['date'] = pd.to_datetime(df_state['date'])
_gas = df_state['gas']
df_state['boom'] = _gas.ge(_gas.mean())

#%% GET VALUE FUNCTIONS -----------------------------------------------------------------
# First-stage estimates for the chosen time period.
t = 'month'
r = pd.read_csv(f'./models/first_stage/r_{t}.csv', index_col=[0])
const = pd.read_csv(f'./models/first_stage/const_{t}.csv', index_col=[0])
sigma = pd.read_csv(f'./models/first_stage/sigma_{t}.csv', index_col=[0])

# For each spec: the saved grid parameters, the UCGrid built from them, and the
# saved search values.
search_grid_params_by_spec = {}
search_grid_by_spec = {}
search_value_by_spec = {}
for spec in ('low', 'mid', 'high'):
    grid_params = np.load(f'./models/value_search/search_grid_{spec}_{t}.npy')
    search_grid_params_by_spec[spec] = grid_params
    # Rows 0-3 of the saved array are each passed to UCGrid as one tuple argument.
    search_grid_by_spec[spec] = UCGrid(
        *(tuple(grid_params[row, :]) for row in range(4))
    )
    search_value_by_spec[spec] = np.load(
        f'./models/value_search/search_value_{spec}_{t}.npy')

#%% GET THE SUMMARY STATISTICS TABLE ----------------------------------------------------
# Each builder below takes a TeX file from ./src/tex/ as input_path and an
# output_path under <overleaf_path>tables/.
visualizations.build_table_summary(
    df_contracts=df_contracts,
    df_state=df_state,
    input_path='./src/tex/table_summary.tex',
    output_path=f'{overleaf_path}tables/table_summary.tex',
    overleaf_path=overleaf_path
)

#%% GET THE PRICE DISPERSION TABLE ------------------------------------------------------
visualizations.build_table_dispersion(
    df=df_contracts,
    input_path='./src/tex/table_dispersion.tex',
    output_path=f'{overleaf_path}tables/table_dispersion.tex',
    overleaf_path=overleaf_path
)

#%% GET THE SORTING TABLE ---------------------------------------------------------------
# The return value is kept: its first element is passed as `ci` to the
# boom/bust graphs further down.
ci_by_reg = visualizations.build_table_sorting(
    df=df_contracts,
    input_path='./src/tex/table_sorting.tex',
    output_path=f'{overleaf_path}tables/table_sorting.tex'
)

#%% GET THE SORTING (POSITIVE ASSORTATIVE MATCHING) GRAPH -------------------------------
visualizations.graph_matching_patterns(
    data=df_contracts,
    output_path=f"{overleaf_path}figures/figure_positive_assortive_matching.pdf"
)

#%% GET THE BOOM/BUST MATCHING PATTERNS -------------------------------------------------
# The module is reloaded in each cell so that edits to visualizations are picked
# up when cells are re-run interactively.
import importlib
importlib.reload(visualizations)
visualizations.graph_boom_bust(
    data=df_contracts,
    ci=ci_by_reg[0],
    output_path=f"{overleaf_path}figures/figure_boom_bust.pdf"
)

#%% GET THE BOOM/BUST MATCHING PATTERNS V2 ----------------------------------------------
import importlib
importlib.reload(visualizations)
visualizations.graph_boom_bust_2(
    data=df_contracts,
    ci=ci_by_reg[0],
    output_path=f"{overleaf_path}figures/figure_boom_bust_2.pdf"
)

#%% GET THE BOOM/BUST MATCHING PATTERNS V3 ----------------------------------------------
import importlib
importlib.reload(visualizations)
visualizations.graph_boom_bust_composition(
    data=df_contracts,
    ci=ci_by_reg[0],
    output_path=f"{overleaf_path}figures/figure_boom_bust_composition.pdf"
)

#%% GET THE VALUE FUNCTION FIGURE -------------------------------------------------------
# Step 1: compute the state evolution / value data from the state data, the
# first-stage estimates and the search grids and values, and save it to disk.
df_state_evolution_by_date = visualizations.get_value_function_data(
    df_state=df_state,
    r=r,
    const=const,
    search_grid_by_spec=search_grid_by_spec,
    search_value_by_spec=search_value_by_spec
)
df_state_evolution_by_date.to_csv('./models/benchmark/state_evolution_by_date.csv')

# Step 2: plot it.
visualizations.value_search_graph(
    df=df_state_evolution_by_date,
    output_path=f"{overleaf_path}figures/figure_value_search.pdf"
)

#%% GET THE RESULTS FOR STATE TRANSITIONS -----------------------------------------------
def _headerless_csv_to_dict(path):
    """Read a header-less CSV indexed by its first column, squeeze it, return a dict."""
    table = pd.read_csv(path, index_col=0, header=None)
    return table.squeeze().to_dict()


transitions_params = _headerless_csv_to_dict(f'./models/first_stage/params_{t}.csv')
transitions_errors = _headerless_csv_to_dict(f'./models/first_stage/errors_{t}.csv')

visualizations.build_table_transitions(
    transitions_params=transitions_params,
    transitions_errors=transitions_errors,
    input_path='./src/tex/table_state_transitions_small.tex',
    output_path=f'{overleaf_path}tables/table_state_transitions_small.tex',
    overleaf_path=overleaf_path
)

#%% GET THE ACCEPTANCE SETS -------------------------------------------------------------
cutoffs_min = pd.read_csv('./models/benchmark/cutoffs_min.csv', index_col=[0])
cutoffs_max = pd.read_csv('./models/benchmark/cutoffs_max.csv', index_col=[0])

# Re-index both cutoff tables by the dates of the state data.
for _cutoffs in (cutoffs_min, cutoffs_max):
    _cutoffs.index = df_state['date']

visualizations.graph_acceptance_sets(
    cutoffs_min,
    cutoffs_max,
    output_path=f"{overleaf_path}figures/figure_acceptance.pdf"
)

#%% GET THE FITTED MOMENTS --------------------------------------------------------------
# Simulated moments (squeezed to a Series) and the two files of empirical moments.
moments_sim = pd.read_csv(
    './models/benchmark/moments_simulated.csv', index_col=[0]
).squeeze("columns")
moments_data_non_coef = pd.read_csv(
    './models/smm_input/moments_empirical.csv', index_col=[0]
)
moments_data_coef = pd.read_csv(
    './models/smm_input/coefs_data_month.csv', index_col=[0]
)
# Sample mean of the 'reneg' column of the contracts data.
mean_reneg = df_contracts['reneg'].mean()

import importlib
importlib.reload(visualizations)

# The returned table is reused by the within-sample fit cell further down.
moments_all = visualizations.smm_moments_table(
    moments_sim,
    moments_data_non_coef,
    moments_data_coef,
    mean_reneg,
    input_path='./src/tex/table_moments_detail.tex',
    output_path=f"{overleaf_path}tables/table_moments_detail.tex"
)

#%% GET THE MATCH PARAMETERS TABLE ------------------------------------------------------
params_match_list = [
    'm_0_low', 'm_1_low', 'm_0_mid', 'm_1_mid', 'm_0_high', 'm_1_high', 'm_2'
]
params_match = {name: params[name] for name in params_match_list}

# Confidence bounds restricted to the match parameters; None when CIs are disabled.
ci_lower_match = None
ci_upper_match = None
if include_ci:
    ci_lower_match = {name: ci_lower[name] for name in params_match_list}
    ci_upper_match = {name: ci_upper[name] for name in params_match_list}

visualizations.build_table_match(
    params_match=params_match,
    ci_lower_match=ci_lower_match,
    ci_upper_match=ci_upper_match,
    delta=delta[0],
    av_price=df_contracts['day_rate'].mean(),
    input_path='./src/tex/table_match.tex',
    output_path=f'{overleaf_path}tables/table_match.tex',
    overleaf_path=overleaf_path
)

#%% GET THE PARAMETERS TABLE ------------------------------------------------------------
params_other_list = [
    'a_1_low', 'a_1_mid', 'a_1_high', 'd_0', 'd_1', 'mu_0', 'sigma_0', 'gamma',
    'gamma_negative', 'p_2', 'p_3', 'eta', 'c', 'rho_0', 'rho_1', 'rho_2', 'rho_3'
]
params_match_list = [
    'm_0_low', 'm_1_low', 'm_0_mid', 'm_1_mid', 'm_0_high', 'm_1_high', 'm_2'
]
params_other = {name: params[name] for name in params_other_list}
params_match = {name: params[name] for name in params_match_list}

# Despite the "_match" suffix, these bounds cover BOTH parameter lists here.
ci_lower_match = None
ci_upper_match = None
if include_ci:
    _all_param_names = params_other_list + params_match_list
    ci_lower_match = {name: ci_lower[name] for name in _all_param_names}
    ci_upper_match = {name: ci_upper[name] for name in _all_param_names}

visualizations.smm_to_tex(
    params_other=params_other,
    params_match=params_match,
    delta=delta[0],
    ci_lower_params_other=ci_lower_match,
    ci_upper_params_other=ci_upper_match,
    input_path='./src/tex/table_smm_parameters.tex',
    output_path=f'{overleaf_path}tables/table_smm.tex',
    overleaf_path=overleaf_path
)

#%% GET THE COUNTERFACTUALS TABLE -------------------------------------------------------
import importlib
importlib.reload(utils)
importlib.reload(visualizations)

df_counterfactuals = pd.read_csv(
    './models/counterfactuals/counterfactual_results.csv',
    index_col=[0]
)
df_counterfactuals_fortnight = pd.read_csv(
    './models/counterfactuals/counterfactual_results_fortnight.csv',
    index_col=[0]
)
# Disabled rescaling of the fortnight totals, kept for reference:
#df_counterfactuals_fortnight['benchmark_fortnight_total_value'] \
 #   = df_counterfactuals_fortnight['benchmark_fortnight_total_value'] / 2
#df_counterfactuals_fortnight['no_sorting_fortnight_total_value'] \
 #   = df_counterfactuals_fortnight['no_sorting_fortnight_total_value'] / 2

# For each counterfactual: the scenario it is compared against ...
comparison_by_counter = {
    'benchmark': 'no_sorting',
    'benchmark_rig_target': 'no_sorting_rig_target',
    'intermediary': 'benchmark',
    'demand_smoothing': 'benchmark',
    'benchmark_fortnight': 'no_sorting_fortnight'
}
# ... and the series name passed as `comparison` to the graph/table builders.
denom_by_counter = {
    'benchmark': 'initial',
    'benchmark_rig_target': 'initial',
    'intermediary': 'final',
    'demand_smoothing': 'initial',
    'benchmark_fortnight': 'initial'
}
data_fortnight, _, _, _, _ = utils.read_in_data(
    time_period='fortnight', use_myopic=True)

# The graph settings do not depend on the counterfactual, so read the file once
# instead of once per loop iteration. The rig-target run reuses the benchmark
# settings.
with open('./src/models_new/settings.json', 'r') as file:
    settings = json.load(file)
settings['benchmark_rig_target'] = settings['benchmark']

series_by_counter = dict()
results_by_counter = dict()
for counter in comparison_by_counter:

    # The fortnight robustness run has its own results file and model data.
    is_fortnight = (counter == 'benchmark_fortnight')
    data_counter = data_fortnight if is_fortnight else data
    df_counter = df_counterfactuals_fortnight if is_fortnight else df_counterfactuals
    gas_price = data_counter['gas_price']

    series_by_counter[counter] = utils.get_counterfactual_decomposition(
        df_counterfactuals=df_counter,
        gas_price=copy.deepcopy(gas_price),
        counter=counter,
        comparison_by_counter=comparison_by_counter[counter]
    )

    # No figure is produced for the fortnight run.
    if not is_fortnight:
        visualizations.counterfactual_graph(
            series_by_counter[counter],
            comparison=denom_by_counter[counter],
            # Pass a copy so every call sees the settings exactly as loaded,
            # as it did when the file was re-read on each iteration.
            settings=copy.deepcopy(settings[counter]),
            output_path=overleaf_path + f'figures/figure_counterfactual_{counter}.pdf'
        )

    # Flag periods whose state variable g is strictly above its sample mean.
    df_series = pd.DataFrame(series_by_counter[counter])
    state_g = data_counter['state_data']['g']
    df_series['boom'] = (state_g > state_g.mean())

    results_by_counter[counter] = visualizations.counterfactual_table(
        series=df_series,
        comparison=denom_by_counter[counter],
        fig_name=counter,
        overleaf_path=overleaf_path
    )

    visualizations.counterfactual_absolute(
        series=df_series,
        comparison=denom_by_counter[counter],
        fig_name=counter,
        overleaf_path=overleaf_path
    )

# Get the robustness test of two-week periods rather than one month.
visualizations.robustness_two_weeks_table(
    vals_benchmark=round(30 * series_by_counter['benchmark']['final'].sum() / 1000, 2),
    vals_benchmark_fortnight=round(30 * series_by_counter['benchmark_fortnight']['final'].sum() / 1000, 2),
    overleaf_path=overleaf_path
)

# Get the robustness test with rig targeting
visualizations.robustness_rig_target_table(
    vals_benchmark=results_by_counter['benchmark'],
    vals_benchmark_rig_target=results_by_counter['benchmark_rig_target'],
    overleaf_path=overleaf_path
)

#%% GET THE DAYRATE GRAPH ---------------------------------------------------------------
# Run the dayrate construction
# The figure is produced by Stata (through pystata) rather than matplotlib and is
# exported to <overleaf_path>figures/figure_dayrate.pdf.
dayrate_path = f"{overleaf_path}figures/figure_dayrate.pdf"
# The Stata program below:
#   1. imports the processed contracts CSV and multiplies day_rate by 1,000,000;
#   2. rebuilds `date` from fixture_date and creates `month1` from month;
#   3. overlays a local-polynomial fit with confidence band (lpolyci) of
#      day_rate on date (axis 1) with a line of gas on month1 (axis 2);
#   4. exports the graph to dayrate_path (the only Python value interpolated
#      into the f-string).
stata.run(f"""
    import delimited "./data_py/processed/contracts_final.csv", clear
    
    replace day_rate = day_rate * 1000000
    drop date
    g date = date(fixture_date, "YMD")
    g month1 = date(month, "YMD")
    
    sort date
    
    twoway 	(lpolyci day_rate date, yaxis(1) color(blue) fcolor(ltblue) msize(vsmall) lpattern(shortdash)) ///
            (line gas month1, yaxis(2) color(gs5) lwidth(medium) ), ///
            ylabel(0(50000)150000, axis(1) nogrid) ylabel(0(5)15, axis(2) nogrid) xtitle("Year") ytitle("Rig Price (USD/day)") xlabel(14610(1000)18000,format(%td_CY)) ///
            ytitle(Rig Price, axis(1)) ytitle(Gas Price, axis(2)) ///
            legend(region(lcolor(white)) order(2 3) lab(3 "Natural Gas Price (USD)") lab(2 "Rig Price (USD/day)") ring(0) position(8)) ///
            graphregion(color(white)) scheme(s2color)
    
    *graph display, xsize(6.5) ysize(6)
    *display "`1'"
    graph export "{dayrate_path}", replace
"""
)
# Disabled alternative: run the stand-alone do-file in Stata batch mode instead.
#os.system(
 #   "/Applications/Stata/StataSE.app/Contents/MacOS/StataSE "
  #  "-b "
   # "do "
    #f"'{root_path}src/models_new/graph_dayrate.do {dayrate_path}'"
    #" &"
#)

#%% GET THE VARIOUS NUMBERS -------------------------------------------------------------
import importlib
importlib.reload(visualizations)

# Merged wells (no imputation), keeping one row per 'api' value.
df_wells_merged = pd.read_csv(
    './data_py/temp/06_merge_contracts_wells/wells_merged_no_impute.csv',
    index_col=[0],
    parse_dates=['spud_date', 'depth_date']
).drop_duplicates('api')
visualizations.get_parameters(df_wells_merged, overleaf_path)

#%% GET PCT DEEPENING  ------------------------------------------------------------------
visualizations.get_deepening_number(
    df_contracts,
    df_contracts_no_deepening,
    overleaf_path
)

#%% GET THE TARGETING NUMBER ------------------------------------------------------------
import importlib
importlib.reload(visualizations)

df_shares = pd.read_csv('./models/benchmark/shares_by_state.csv', index_col=[0])

visualizations.get_targeting_explanation(
    params,
    df_state,
    data['match_values_by_tau_spec'],
    data['surplus_grid'],
    df_shares,
    overleaf_path
)
#%% GET THE WITHIN-SAMPLE FIT -----------------------------------------------------------
# Uses the moments table returned by smm_moments_table above.
import importlib
importlib.reload(visualizations)
visualizations.sorting_fit(moments_all, overleaf_path)

#%% GET THE MISMATCH TABLE --------------------------------------------------------------
# Each cell reloads visualizations first so it can be re-run on its own after
# editing that module.
import importlib
importlib.reload(visualizations)
visualizations.make_table_mismatch(
    df_contracts, df_state, params, overleaf_path
)

#%% GET HHI -----------------------------------------------------------------------------
import importlib
importlib.reload(visualizations)
visualizations.get_data_section_parameters(
    df_contracts, overleaf_path
)

#%% GET RELATIONSHIPS -------------------------------------------------------------------
visualizations.get_relationships(
    df_contracts, overleaf_path
)

#%% GET UTILIZATION ---------------------------------------------------------------------
import importlib
importlib.reload(visualizations)
visualizations.get_utilization_table(
    df_state, overleaf_path
)

#%% GET DATA CONSTRUCTION PARAMETERS ----------------------------------------------------
import importlib
importlib.reload(visualizations)


def _read_stage(path, date_columns):
    """Read an intermediate CSV with its first column as index and dates parsed."""
    return pd.read_csv(path, index_col=[0], parse_dates=date_columns)


# Wells after the map merge (step 05).
df_wells_with_mri = _read_stage(
    './data_py/temp/05_merge_with_map/wells_with_map.csv',
    ['spud_date', 'depth_date']
)

# How many init contracts? (step 04)
df_contracts_init = _read_stage(
    './data_py/temp/04_clean_contracts/contracts_processed.csv',
    ['fixture_date']
)

# How many contracts with map? (step 05)
df_contracts_with_map = _read_stage(
    './data_py/temp/05_merge_with_map/contracts_with_map.csv',
    ['fixture_date']
)

# How many merged (step 06, before imputation)
df_merged_wells_no_impute = _read_stage(
    './data_py/temp/06_merge_contracts_wells/wells_merged_no_impute.csv',
    ['spud_date', 'depth_date']
)
df_merged_contracts_no_impute = _read_stage(
    './data_py/temp/06_merge_contracts_wells/contracts_collapse_no_impute.csv',
    ['fixture_date']
)

# Get matched contracts (before imputation)
visualizations.get_data_construction(
    df_wells_with_mri,
    df_contracts_init,
    df_contracts_with_map,
    df_merged_wells_no_impute,
    df_merged_contracts_no_impute,
    df_contracts,
    overleaf_path
)

#%% GET OUT OF SAMPLE FIT ---------------------------------------------------------------
import importlib
importlib.reload(visualizations)

df_state_long = pd.read_csv(
    './data_py/processed/states_month_long.csv',
    index_col=[0],
    parse_dates=['month']
)

# NOTE: this rebinds the module-level data/delta/rho/c/weights loaded at the top
# of the script; `data` is then overwritten piece by piece with the "long"
# sample below.
data, delta, rho, c, weights = utils.read_in_data(use_myopic=True)

# Add in stuff for longer contracts
path_price_match_values = './models/price_match/price_match_values'

# Price match values for the long sample: one array per (spec, k) JSON file.
price_match_values_by_spec = dict()
for spec in ['low', 'mid', 'high']:
    price_match_values_by_spec[spec] = dict()
    for k in [0, 1, 2, 3]:
        with open(
                f"{path_price_match_values}_{spec}_{k}_month_long.json") as f:
            price_match_values_by_spec[spec][k] = np.array(json.load(f))
data['price_match_values_by_spec'] = price_match_values_by_spec

df_contracts_long = pd.read_csv(
    './data_py/processed/contracts_final_long.csv',
    index_col=[0],
    parse_dates=['fixture_date']
)

# Point estimates as a plain {name: value} dict.
params = pd.read_csv('./models/smm/params_smm_with_diff_new.csv', index_col=[0]).squeeze("columns")
params = dict(params)

# NOTE(review): this stores the regular contracts frame, while the long frame is
# what gets passed to get_out_of_sample below - confirm this is intended.
data['df_contracts'] = df_contracts

# 'value' column of the long contracts, split by spec.
values_by_spec = {
    spec: df_contracts_long.loc[df_contracts_long['spec'] == spec, 'value']
    for spec in ['low', 'mid', 'high']
}
data['values_by_spec'] = values_by_spec

data['state_data'] = df_state_long

# Work on an explicit copy of the column subset: adding columns to the bare
# selection triggers pandas' SettingWithCopyWarning.
g_data = df_state_long[['date', 'month', 'g']].copy()
g_data['date'] = pd.to_datetime(g_data['date'])
g_data['2006'] = (g_data['date'].dt.year == 2006)  # flags rows dated in 2006
data['g_data'] = g_data

visualizations.get_out_of_sample(
    df_state=df_state_long,
    df_contracts=df_contracts_long,
    data=data,
    params=params,
    weights=weights,
    output_path=overleaf_path + "figures/figure_out_of_sample.pdf"
)
